import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import MultipleLocator


def plot_box_results(box_model_prefix, max_fi=0):
    """Compare the model BM response at the POI with Dewey et al. 2021.

    Reads the Dewey et al. 2021 BM magnitude/phase data, the Dong et al. 2013
    middle-ear magnitude/phase data and the COMSOL point-of-interest results
    ``csv/<box_model_prefix>_POI_bm_and_ooc_results.csv``. It then plots BM
    magnitude (dB re 1 nm/Pa at EC) and phase (cycles re EC) against
    frequency in kHz. The figure is saved to
    ``manufigs/<box_model_prefix>_POI_dewey2021.pdf`` and shown.

    Parameters
    ----------
    box_model_prefix : str
        Prefix identifying the model run's result files.
    max_fi : int, optional
        Number of frequency-sorted model points to plot, with the usual
        ``[:max_fi]`` slice semantics. ``0`` (the default) plots every model
        point. The same selection is applied to the magnitude and phase panels.
    """
    exp_bm_mag_df = pd.read_csv('csv/dewey_2021_bm_mag.csv', index_col=['freq (kHz)']).sort_index()
    exp_bm_phase_df = pd.read_csv('csv/dewey_2021_bm_phase.csv', index_col=['freq (kHz)']).sort_index()

    me_gain_df = pd.read_csv('csv/dong_2013_mag.csv', index_col=['freq (kHz)']).sort_index()
    me_phase_df = pd.read_csv('csv/dong_2013_phase.csv', index_col=['freq (kHz)']).sort_index()

    csv_poi_results = 'csv/' + box_model_prefix + '_POI_bm_and_ooc_results.csv'
    comsol_POI_df = pd.read_csv(csv_poi_results, index_col=['freq']).sort_index()

    fig, ax = plt.subplots(2, sharex='col', sharey='row', figsize=[4, 4.5])

    # Experiment: Dewey et al. 2021, stimulus level 90 dB re 20 uPa at EC.
    ec_pressure_db = 90
    # Invert ec_pressure_db = 20*log10(ec_pressure_Pa / 2e-5) to get pascals.
    ec_pressure_Pa = 2e-5 * 10**(ec_pressure_db / 20)

    exp_bm_mag_f = exp_bm_mag_df.index
    exp_bm_mag = exp_bm_mag_df.loc[:, 'mag (nm)']
    exp_bm_mag_nm_per_Pa = exp_bm_mag / ec_pressure_Pa
    exp_bm_db = 20 * np.log10(exp_bm_mag_nm_per_Pa)
    exp_bm_phase_f = exp_bm_phase_df.index
    exp_bm_phase = exp_bm_phase_df.loc[:, 'phase (cycles)']

    # Middle-ear transfer function: Dong et al. 2013 (frequencies in kHz).
    me_mag_f = me_gain_df.index
    me_mag = me_gain_df.loc[:, 'mag (mm/s/Pa)']
    me_phase_f = me_phase_df.index
    me_phase = me_phase_df.loc[:, 'phase (cycles)']

    # Model results are indexed in Hz. Plain arrays keep later slicing
    # positional regardless of the index dtype.
    comsol_f_kHz = np.asarray(comsol_POI_df.index) / 1000
    # Note that u0 = 1 mm/s, so this is nm/(mm/s) as well.
    comsol_bm_nm_per_mm_per_s = np.asarray(comsol_POI_df.loc[:, 'bm_z_mag_nm'])
    # nm/(mm/s) * (mm/s)/Pa = nm/Pa
    comsol_bm_nm_per_Pa = comsol_bm_nm_per_mm_per_s * np.interp(comsol_f_kHz, me_mag_f, me_mag)
    comsol_bm_db = 20 * np.log10(comsol_bm_nm_per_Pa)
    # Radians -> cycles, then add the middle-ear phase to reference it to EC.
    comsol_bm_phase = np.asarray(comsol_POI_df.loc[:, 'bm_z_phase']) / 2 / np.pi
    comsol_bm_phase_re_EC = comsol_bm_phase + np.interp(comsol_f_kHz, me_phase_f, me_phase)

    # Shift the whole curve by whole cycles so its first point lies in [-0.5, 0.5].
    while comsol_bm_phase_re_EC[0] < -0.5:
        comsol_bm_phase_re_EC = comsol_bm_phase_re_EC + 1
    while comsol_bm_phase_re_EC[0] > 0.5:
        comsol_bm_phase_re_EC = comsol_bm_phase_re_EC - 1

    # One selection shared by both panels. Previously max_fi == 0 plotted
    # every magnitude point but an empty phase curve.
    model_points = slice(None) if max_fi == 0 else slice(max_fi)

    ax[0].plot(exp_bm_mag_f, exp_bm_db, marker='^', label='exp Dewey et al. 2021', color='tab:blue')
    ax[0].plot(comsol_f_kHz[model_points], comsol_bm_db[model_points], '--', marker='o',
               label='model results', color='tab:green')
    ax[1].plot(exp_bm_phase_f, exp_bm_phase, marker='^', label='OCT data Dewey et al. 2021',
               color='tab:blue')
    ax[1].plot(comsol_f_kHz[model_points], comsol_bm_phase_re_EC[model_points], '--', marker='o',
               label='model results', color='tab:green')
    ax[0].set_xlim(0, 12)
    ax[0].set_ylim(-10, 40)

    ax[1].set_xlim(0, 12)
    ax[1].set_ylim(-4, 1)
    ax[1].set_xlabel('frequency (kHz)')
    ax[0].set_ylabel('BM mag (dB re 1 nm/Pa at EC)')
    ax[1].set_ylabel('BM phase (cycle re EC)')
    for axx in ax:
        axx.spines['right'].set_visible(False)
        axx.spines['top'].set_visible(False)

    ax[1].legend()
    fig.tight_layout()
    fig.savefig('manufigs/' + box_model_prefix + '_POI_dewey2021.pdf', transparent=True)
    plt.show()


def plot_freq_map(box_model_prefix):
    """Plot best frequency against longitudinal location for the model and references.

    Curves, in legend order:
    - the Muller et al. 2005 map;
    - the same map divided by sqrt(2);
    - the Nankali et al. 2022 points;
    - the model's per-frequency ``max_location`` values.

    Frequencies are plotted in Hz on a log axis with kHz tick labels. The
    figure is written to ``manufigs/<box_model_prefix>_map_no_data.pdf`` and
    shown.
    """
    length_mm = 5.8  # x-axis extent; also the length used in the % conversion

    # Nankali et al. 2022: location index scaled by the length, kHz -> Hz.
    # NOTE(review): the column is named 'x (%)' yet is multiplied by 5.8 without
    # dividing by 100, i.e. treated as a fraction. Confirm the CSV's units.
    nankali = pd.read_csv('csv/nankali_2022_map.csv', index_col=['x (%)']).sort_index()
    nankali_x = nankali.index * length_mm
    nankali_freq = nankali.loc[:, 'frequency (kHz)'] * 1000

    # Muller et al. 2005 map sampled every 0.1 mm; location converted to percent.
    muller_x = np.arange(0, length_mm, 0.1)
    percent_distance = muller_x / length_mm * 100
    muller_freq = 10**((156.5 - percent_distance) / 82.5) * 1000

    csv_stem = 'csv/' + box_model_prefix
    poi_df = pd.read_csv(csv_stem + '_POI_bm_and_ooc_results.csv', index_col=['freq']).sort_index()
    max_location_df = pd.read_csv(csv_stem + '_max_location_results.csv')
    # NOTE(review): max_location rows are paired positionally with the sorted POI
    # frequencies. This assumes both CSVs list the same frequencies in ascending order.
    model_freqs = poi_df.index
    model_peak_x = max_location_df.loc[:, 'max_location']

    fig, ax = plt.subplots(1, sharex='col', sharey='row', figsize=[3.5, 4])

    curves = (
        (muller_x, muller_freq,
         {'label': 'Muller et al. 2005', 'color': 'tab:blue'}),
        (muller_x, muller_freq / np.sqrt(2),
         {'label': 'Shifted Muller et al.', 'color': 'tab:orange'}),
        (nankali_x, nankali_freq,
         {'marker': '^', 'label': 'Nankali et al. 2022', 'linestyle': 'None', 'color': 'tab:purple'}),
        (model_peak_x, model_freqs,
         {'label': 'model results', 'marker': 's', 'color': 'tab:green'}),
    )
    for x_vals, y_vals, line_style in curves:
        ax.plot(x_vals, y_vals, **line_style)

    ax.set_xlim(0, length_mm)
    ax.legend()
    ax.set_yscale('log')
    ax.set_xlabel('longitudinal location (mm)')
    ax.set_ylabel('best frequency (kHz)')
    tick_hz = [4000, 5000, 10000, 30000, 50000]
    ax.set_yticks(tick_hz)
    ax.set_yticklabels([str(hz // 1000) for hz in tick_hz])
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    fig.tight_layout()
    plt.savefig('manufigs/' + box_model_prefix + '_map_no_data.pdf', transparent=True)
    plt.show()

def _mask_low_magnitude(mag, phase, threshold):
    """Suppress samples whose magnitude is at or below *threshold*.

    Returns new float arrays ``(mag, phase)``:
    - masked magnitudes become NaN, so they plot as gaps and ``log10`` does
      not warn on zero;
    - each masked phase repeats the preceding sample's (possibly already held)
      phase;
    - a masked first sample has no predecessor and keeps its own phase.
    """
    mag = np.array(mag, dtype=float)
    phase = np.array(phase, dtype=float)
    low = mag <= threshold
    # Ascending order lets a run of masked samples carry the last valid phase forward.
    for i in np.flatnonzero(low):
        if i > 0:
            phase[i] = phase[i - 1]
    mag[low] = np.nan
    return mag, phase


def plot_bm_mag_phase(box_model_prefix, mag_threshold=0.1):
    """Plot BM magnitude and phase against longitudinal location, one curve per frequency.

    Reads ``csv/<box_model_prefix>_bm_results.csv``. Its two-level columns are
    the frequency label (Hz) and then ``'bm_z_mag_nm'`` / ``'bm_z_phase'``.
    Each curve is referenced to EC pressure using the Dong et al. 2013
    middle-ear magnitude and phase. The figure is saved to
    ``manufigs/<box_model_prefix>_mid_line_results.pdf`` and shown.

    Parameters
    ----------
    box_model_prefix : str
        Prefix identifying the model run's result files.
    mag_threshold : float, optional
        Raw ``bm_z_mag_nm`` value at or below which a point is treated as
        negligible. Its magnitude is not drawn and its phase holds the
        previous point's value.
    """
    bm_df = pd.read_csv('csv/' + box_model_prefix + '_bm_results.csv', index_col=[0], header=[0, 1])
    # Top-level column labels are frequencies in Hz (divided by 1000 for the kHz ME data).
    freqs = sorted({int(label) for label in bm_df.columns.get_level_values(0)})

    me_gain_df = pd.read_csv('csv/dong_2013_mag.csv', index_col=['freq (kHz)']).sort_index()
    me_phase_df = pd.read_csv('csv/dong_2013_phase.csv', index_col=['freq (kHz)']).sort_index()

    me_mag_f = me_gain_df.index
    me_mag = me_gain_df.loc[:, 'mag (mm/s/Pa)']
    me_phase_f = me_phase_df.index
    me_phase = me_phase_df.loc[:, 'phase (cycles)']

    X = np.asarray(bm_df.index)  # longitudinal location, plotted in mm

    fig, ax = plt.subplots(2, sharex='col', sharey='row', figsize=[4, 4.5])

    for freqi in freqs:
        mag = np.asarray(bm_df.loc[:, (str(freqi), 'bm_z_mag_nm')])
        phase = np.asarray(bm_df.loc[:, (str(freqi), 'bm_z_phase')]) / 2 / np.pi  # rad -> cycles
        # Mask before converting. The original masking loop ran after the
        # plotted values were computed, so it had no effect.
        mag, phase = _mask_low_magnitude(mag, phase, mag_threshold)

        mag_per_Pec = mag * np.interp(freqi / 1000, me_mag_f, me_mag)  # nm/(mm/s)*mm/s/Pa = nm/Pa
        phase_re_EC = phase + np.interp(freqi / 1000, me_phase_f, me_phase)

        ax[0].plot(X, 20 * np.log10(mag_per_Pec))
        ax[1].plot(X, phase_re_EC)

    # Axis configuration does not depend on the frequency, so set it once.
    ax[0].set_ylabel('mag (dB re 1 nm/Pa at EC)')
    ax[0].set_ylim(-10, 40)
    ax[1].set_ylabel('phase (cycle re EC)')
    ax[1].set_xlabel('longitudinal location (mm)')
    ax[1].set_xlim(0, 5.8)
    ax[1].set_ylim(-4, 1)

    for axx in ax:
        axx.spines['right'].set_visible(False)
        axx.spines['top'].set_visible(False)

    fig.tight_layout()
    pdf_results = 'manufigs/' + box_model_prefix + '_mid_line_results.pdf'
    fig.savefig(pdf_results, transparent=True)

    plt.show()

def plot_k_and_p_diff(box_model_prefix, max_fi=7):
    """Plot the model BM wavenumber and the pressure difference against frequency.

    Top panel: ``bm_k`` from ``csv/<box_model_prefix>_POI_bm_and_ooc_results.csv``,
    converted from rad/m to cycle/mm. Bottom panel: ``p_diff_mag`` from
    ``csv/p_diff_box.csv`` divided by 10. The figure is saved to
    ``manufigs/<box_model_prefix>_k_f_p_diff.pdf`` and shown.

    Parameters
    ----------
    box_model_prefix : str
        Prefix identifying the model run's result files.
    max_fi : int, optional
        Number of frequency-sorted points to plot (default 7, the previous
        hard-coded value). ``0`` plots every point, matching ``plot_box_results``.
    """
    csv_poi_results = 'csv/' + box_model_prefix + '_POI_bm_and_ooc_results.csv'
    comsol_POI_df = pd.read_csv(csv_poi_results, index_col=['freq']).sort_index()
    # Plain arrays keep the [:max_fi] slices positional; a []-slice on a float
    # frequency index would be label-based and could come out empty.
    comsol_f_kHz = np.asarray(comsol_POI_df.index) / 1000  # Hz -> kHz
    comsol_k = np.asarray(comsol_POI_df.loc[:, 'bm_k']) / 1000 / 2 / np.pi  # from rad/m to cycle/mm

    # NOTE(review): this file is not keyed by box_model_prefix, and its rows are
    # paired positionally with the POI frequencies. Confirm both cover the same
    # frequency grid.
    p_diff_df = pd.read_csv('csv/p_diff_box.csv', index_col=['freq']).sort_index()
    # NOTE(review): reduced by a factor of 10; the reason is not stated here.
    p_diff_mag = np.asarray(p_diff_df.loc[:, 'p_diff_mag']) / 10

    fig, ax = plt.subplots(2, sharey='row', figsize=[3.5, 5.5])

    points = slice(None) if max_fi == 0 else slice(max_fi)

    ax[0].plot(comsol_f_kHz[points], comsol_k[points], '--', marker='s', color='tab:green')
    ax[0].set_ylim(0, 2)
    ax[0].set_xlabel('frequency (kHz)')
    ax[0].set_ylabel('wavenumber (cycle/mm)')

    ax[1].plot(comsol_f_kHz[points], p_diff_mag[points], '--', marker='s', color='tab:green')
    ax[1].set_xlim(4 - 0.2, 7 + 0.2)
    ax[1].set_ylim(0, 3)
    ax[1].set_xlabel('frequency (kHz)')
    ax[1].set_ylabel('pressure difference (Pa)')

    for axi in ax:
        axi.spines['right'].set_visible(False)
        axi.spines['top'].set_visible(False)
        # Integer-kHz major ticks with half-kHz minor ticks.
        axi.xaxis.set_major_locator(MultipleLocator(1))
        axi.xaxis.set_major_formatter('{x:.0f}')
        axi.xaxis.set_minor_locator(MultipleLocator(0.5))

    fig.tight_layout()
    fig.savefig('manufigs/' + box_model_prefix + '_k_f_p_diff.pdf', transparent=True)
    plt.show()

if __name__ == '__main__':
    # Generate all figures for one model run. The functions are called directly:
    # the module-qualified name `plot_manufig_fig1_box_model` is never imported
    # here, so the previous unguarded calls raised NameError on run and on import.
    model_prefix = 'mouse_box_model_with_OoC_run196f'
    plot_box_results(model_prefix, 10)
    plot_freq_map(model_prefix)
    plot_bm_mag_phase(model_prefix)
    plot_k_and_p_diff(model_prefix)